use aya_ebpf::{
    helpers::bpf_probe_read_user_str_bytes,
    macros::{map, uprobe, uretprobe},
    maps::PerfEventArray,
    programs::ProbeContext,
    EbpfContext,
};
use core::ffi::c_char;
use pg_ferret_shared::{
    Event, FunctionName, PostgresEntry, QueryText, ThreadId, FUNC_NAME_LEN, QUERY_TEXT_LEN,
};

#[map]
/// Map through which the probes in this file (via `submit_entry` and
/// `submit_return`) send their [`Event`]s to the userspace collector. The
/// collector reads these entry/return events to reconstruct the trace of
/// function calls in the Postgres backend.
///
/// This is a perf event array (`PerfEventArray`), not a BPF ring buffer map.
/// Userspace has to drain it. If the collector does not keep up and the buffer
/// is full, events emitted at that moment are lost: the probes do not block and
/// do not retry. Under heavy load the reconstructed trace can therefore have gaps.
///
/// NOTE(review): `1024` is passed as the map's max entries. For a perf event
/// array this presumably bounds the number of per-CPU slots; confirm against the
/// userspace loader.
pub static mut EVENTS: PerfEventArray<Event> = PerfEventArray::with_max_entries(1024, 0);

#[uprobe]
/// Entry hook for Postgres' `exec_simple_query`, one of the main top level
/// functions a backend runs when a client sends a query. This path serves the
/// simple query protocol, i.e. queries that do not use the extended protocol.
///
/// Emits an `ExecSimpleQuery` entry event with the query text (read from the
/// function's first argument) plus the tgid and pid of the current task.
pub fn exec_simple_query_entry(ctx: ProbeContext) -> u32 {
    // NOTE(review): the tgid goes in the first id slot and the pid in the second.
    // In aya, `tgid()` is the process id and `pid()` the kernel thread id. For
    // single-threaded Postgres backends they coincide; confirm the collector
    // agrees on the order.
    let entry = PostgresEntry::ExecSimpleQuery(query_text(&ctx), ctx.tgid(), ctx.pid());
    submit_entry(ctx, entry)
}

#[uretprobe]
/// Return hook for Postgres' `exec_simple_query`. It emits a return event that
/// tells the collector `exec_simple_query` has finished.
pub fn exec_simple_query_return(ctx: ProbeContext) -> u32 {
    const FUNC: &str = "exec_simple_query";
    // `ctx` is moved into `submit_return`, so the id is read beforehand.
    let tgid = ctx.tgid();
    submit_return(ctx, FUNC, tgid)
}

#[uprobe]
/// Entry hook for Postgres' `exec_parse_message`, a top level function a backend
/// runs when a client uses the extended query protocol. Of the extended
/// protocol functions, it is the only one that takes the query string as an
/// argument.
///
/// Emits an `ExecParseMessage` entry event with the query text (read from the
/// function's first argument) plus the tgid and pid of the current task.
pub fn exec_parse_message_entry(ctx: ProbeContext) -> u32 {
    let entry = PostgresEntry::ExecParseMessage(query_text(&ctx), ctx.tgid(), ctx.pid());
    submit_entry(ctx, entry)
}

#[uretprobe]
/// Return hook for Postgres' `exec_parse_message`. It emits a return event that
/// tells the collector `exec_parse_message` has finished.
pub fn exec_parse_message_return(ctx: ProbeContext) -> u32 {
    const FUNC: &str = "exec_parse_message";
    // `ctx` is moved into `submit_return`, so the id is read beforehand.
    let tgid = ctx.tgid();
    submit_return(ctx, FUNC, tgid)
}

#[uprobe]
/// Entry hook for Postgres' `exec_bind_message`, the second function called in
/// the extended query protocol. It carries no query text, only the tgid and pid
/// of the current task.
pub fn exec_bind_message_entry(ctx: ProbeContext) -> u32 {
    let (tgid, pid) = (ctx.tgid(), ctx.pid());
    submit_entry(ctx, PostgresEntry::ExecBindMessage(tgid, pid))
}

#[uretprobe]
/// Return hook for Postgres' `exec_bind_message`. It emits a return event that
/// tells the collector `exec_bind_message` has finished.
pub fn exec_bind_message_return(ctx: ProbeContext) -> u32 {
    const FUNC: &str = "exec_bind_message";
    // `ctx` is moved into `submit_return`, so the id is read beforehand.
    let tgid = ctx.tgid();
    submit_return(ctx, FUNC, tgid)
}

#[uprobe]
/// Entry hook for Postgres' `exec_execute_message`, the third function called in
/// the extended query protocol. It carries no query text, only the tgid and pid
/// of the current task.
pub fn exec_execute_message_entry(ctx: ProbeContext) -> u32 {
    let (tgid, pid) = (ctx.tgid(), ctx.pid());
    submit_entry(ctx, PostgresEntry::ExecExecuteMessage(tgid, pid))
}

#[uretprobe]
/// Return hook for Postgres' `exec_execute_message`. It emits a return event
/// that tells the collector `exec_execute_message` has finished.
pub fn exec_execute_message_return(ctx: ProbeContext) -> u32 {
    const FUNC: &str = "exec_execute_message";
    // `ctx` is moved into `submit_return`, so the id is read beforehand.
    let tgid = ctx.tgid();
    submit_return(ctx, FUNC, tgid)
}

/// Reads the NUL-terminated query string pointed to by the probed function's
/// first argument. The string lives in user-space (Postgres) memory and is
/// copied into a fixed size, zero-initialised buffer that can be passed to the
/// userspace collector.
///
/// The read is best effort and never panics; eBPF programs have no meaningful
/// way to recover from a panic:
/// - If the first argument cannot be fetched, or the pointer is null, an
///   all-zero buffer is returned.
/// - If the user-space read fails (e.g. the page is not resident), the error is
///   ignored and the buffer is returned as-is. The event is still emitted.
fn query_text(ctx: &ProbeContext) -> QueryText {
    let mut buf = [0u8; QUERY_TEXT_LEN];

    let Some(arg0) = ctx.arg::<*const c_char>(0) else {
        return buf;
    };
    if arg0.is_null() {
        return buf;
    }

    // SAFETY: `bpf_probe_read_user_str_bytes` reads user memory through a
    // fault-tolerant kernel helper, so a bad pointer produces an `Err` rather
    // than a fault. The result is intentionally discarded: a failed read must
    // not abort the probe, it just yields an empty or partial query text.
    let _ = unsafe { bpf_probe_read_user_str_bytes(arg0 as *const u8, &mut buf) };
    buf
}

/// Helper function for submitting an entry event to the userspace collector.
/// This function is used by the special hooks for the top level functions in
/// Postgres, as well as the autogenerated hooks for the lower level functions.
///
/// `event` is wrapped in [`Event::Entry`] and written to [`EVENTS`]. The function
/// always returns `0`, which the probe entry points hand back as their own
/// return value.
pub fn submit_entry(ctx: ProbeContext, event: PostgresEntry) -> u32 {
    // SAFETY: eBPF programs are *not* single threaded. The same program can run
    // on several CPUs at once, so exclusive access cannot justify this block.
    // It is sound because `EVENTS` is only used through a shared reference to
    // call `output`; no code here assigns to the static. `addr_of!` avoids
    // creating a reference directly to the `static mut`, which the
    // `static_mut_refs` lint flags.
    unsafe {
        (*core::ptr::addr_of!(EVENTS)).output(&ctx, &Event::Entry(event), 0);
    }
    0
}

/// Helper function for submitting a return event to the userspace collector.
/// This function is used by the special hooks for the top level functions in
/// Postgres, as well as the autogenerated hooks for the lower level functions.
///
/// `func` is packed into a fixed size name with `str_to_func` (byte-wise
/// truncated if too long). It is sent as an [`Event::Return`] together with
/// `thread_id`. The function always returns `0`, which the return probes hand
/// back as their own return value.
pub fn submit_return(ctx: ProbeContext, func: &str, thread_id: ThreadId) -> u32 {
    let func = str_to_func(func);
    // SAFETY: eBPF programs are *not* single threaded. The same program can run
    // on several CPUs at once, so exclusive access cannot justify this block.
    // It is sound because `EVENTS` is only used through a shared reference to
    // call `output`; no code here assigns to the static. `addr_of!` avoids
    // creating a reference directly to the `static mut`, which the
    // `static_mut_refs` lint flags.
    unsafe {
        (*core::ptr::addr_of!(EVENTS)).output(&ctx, &Event::Return(func, thread_id), 0);
    }
    0
}

/// Packs `str` into a fixed size, zero padded [`FunctionName`] that can be passed
/// to the userspace collector.
///
/// At most `FUNC_NAME_LEN` bytes are copied, and truncation is byte-wise.
/// Shorter names are padded with trailing zero bytes. A name of
/// `FUNC_NAME_LEN` bytes or more fills the whole array, so the result then has
/// no terminating zero byte.
pub fn str_to_func(str: &str) -> FunctionName {
    let bytes = str.as_bytes();
    let copied = bytes.len().min(FUNC_NAME_LEN);
    let mut name = [0u8; FUNC_NAME_LEN];
    name[..copied].copy_from_slice(&bytes[..copied]);
    name
}
